/*
 * Created on Jul 12, 2008
 */
package edu.swri.swiftvis.plot.p3d.raytrace;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.GridLayout;
import java.awt.Image;
import java.awt.geom.Rectangle2D;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;

import javax.swing.JComponent;
import javax.swing.JPanel;

import edu.swri.swiftvis.plot.PlotArea3D;
import edu.swri.swiftvis.plot.p3d.AmbientLight;
import edu.swri.swiftvis.plot.p3d.Basic;
import edu.swri.swiftvis.plot.p3d.Camera3D;
import edu.swri.swiftvis.plot.p3d.DirectionalLight;
import edu.swri.swiftvis.plot.p3d.DrawAttributes;
import edu.swri.swiftvis.plot.p3d.PointLight;
import edu.swri.swiftvis.plot.p3d.RenderEngine;
import edu.swri.swiftvis.plot.p3d.Sphere3D;
import edu.swri.swiftvis.plot.p3d.Triangle3D;
import edu.swri.swiftvis.plot.p3d.Vect3D;
import edu.swri.swiftvis.util.EditableDouble;
import edu.swri.swiftvis.util.EditableInt;
import edu.swri.swiftvis.util.LoopBody;
import edu.swri.swiftvis.util.ThreadHandler;

/**
 * A {@link RenderEngine} that draws the scene by casting rays from the camera
 * position through every pixel of an off-screen image.
 *
 * Geometry handed to the engine is wrapped in ray-trace primitives and, once
 * the scene is declared complete, inserted into an {@link Octree} that is used
 * for all ray/geometry queries.  Shading combines the ambient light, the point
 * lights, the directional lights, reflections, transmission through partially
 * transparent surfaces and optional randomly scattered rays.
 */
public class RayTraceEngine implements RenderEngine {
    /**
     * Creates an engine bound to a plot area.
     * @param pa the plot area that supplies the camera and the sink used when
     *           resolving light positions/directions and the camera vectors
     */
    public RayTraceEngine(PlotArea3D pa) {
        plotArea=pa;
    }

    /**
     * @return the display name of this engine
     */
    @Override
    public String toString() {
        return "Ray Tracing Engine";
    }

    /**
     * Adds a sphere to the scene.
     * @param s the sphere whose center and radius are copied into a ray-trace sphere
     * @param attrs the drawing attributes applied to the sphere
     */
    public void add(Sphere3D s,DrawAttributes attrs) {
        RTSphere rtSphere=new RTSphere(s.getCenter(),s.getRadius());
        rtSphere.setAttributes(attrs);
        geomList.add(rtSphere);
    }

    /**
     * Adds a triangle to the scene.
     * @param t the triangle to wrap in a ray-trace triangle
     * @param attrs the drawing attributes handed to the ray-trace triangle
     */
    public void add(Triangle3D t,DrawAttributes[] attrs) {
        RTTriangle rtTriangle=new RTTriangle(t);
        rtTriangle.setAttributes(attrs);
        geomList.add(rtTriangle);
    }

    /**
     * Registers a point light that will be used by subsequent renders.
     * @param pl the light to add
     */
    public void addPointLight(PointLight pl) {
        pointLights.add(pl);
    }

    /**
     * Registers a directional light that will be used by subsequent renders.
     * @param dl the light to add
     */
    public void addDirectionalLight(DirectionalLight dl) {
        directLights.add(dl);
    }

    /**
     * Replaces the ambient light.
     * @param al the new ambient light
     */
    public void setAmbientLight(AmbientLight al) {
        ambientLight=al;
    }

    /**
     * Removes all lights (except the ambient light) and all geometry.
     * The octree built by the last call to {@link #sceneComplete(int,int)}
     * and the ambient light are left untouched.
     * NOTE(review): because the tree is kept, a cameraMoved() issued before the
     * next sceneComplete() still renders the previous geometry - confirm this
     * is intended.
     */
    public void clearScene() {
        pointLights.clear();
        directLights.clear();
        geomList.clear();
    }

    /**
     * Builds the octree for the geometry added so far and renders the scene.
     *
     * The octree is centered on, and sized to, the axis-aligned box that
     * encloses the bounding spheres of all geometry.
     * NOTE(review): when no geometry has been added the extents keep their
     * sentinel values, so the size handed to the Octree is negative - confirm
     * that Octree tolerates this.
     *
     * @param imgWidth the width of the image to render, in pixels
     * @param imgHeight the height of the image to render, in pixels
     * @return the image held by this engine; it is null if no render has ever
     *         run (rendering is skipped for non-positive sizes)
     */
    public Image sceneComplete(int imgWidth, int imgHeight) {
        sx=imgWidth;
        sy=imgHeight;
        double minx=1e100,maxx=-1e100,miny=1e100,maxy=-1e100,minz=1e100,maxz=-1e100;
        for(RTGeometry g:geomList) {
            RTSphere bs=g.getBoundingSphere();
            double r=bs.getRadius();
            double lowX=bs.getX()-r,highX=bs.getX()+r;
            double lowY=bs.getY()-r,highY=bs.getY()+r;
            double lowZ=bs.getZ()-r,highZ=bs.getZ()+r;
            if(lowX<minx) { minx=lowX; }
            if(highX>maxx) { maxx=highX; }
            if(lowY<miny) { miny=lowY; }
            if(highY>maxy) { maxy=highY; }
            if(lowZ<minz) { minz=lowZ; }
            if(highZ>maxz) { maxz=highZ; }
        }
        // The tree is a cube: use the longest extent for every axis.
        double size=Math.max(maxx-minx,Math.max(maxy-miny,maxz-minz));
        tree=new Octree(0.5*(maxx+minx),0.5*(maxy+miny),0.5*(maxz+minz),size);
        for(RTGeometry g:geomList) {
            tree.addObject(g);
        }
        tree.getRoot().finalize();
        renderImage();
        return img;
    }

    /**
     * Re-renders the current scene with the camera as it is now.
     * @return the image held by this engine (null if nothing has been rendered yet)
     */
    public Image cameraMoved() {
        renderImage();
        return img;
    }

    /**
     * Lazily builds and returns the panel with the editable render settings.
     * @return the (cached) settings panel
     */
    public JComponent getPropertiesPanel() {
        if(propPanel!=null) {
            return propPanel;
        }
        JPanel northPanel=new JPanel(new GridLayout(5,1));
        northPanel.add(raysPerPixel.getLabeledTextField("Rays Per Pixel",null));
        northPanel.add(scale.getLabeledTextField("Lighting Length Scale",null));
        northPanel.add(maxRecursions.getLabeledTextField("Maximum Recursions",null));
        northPanel.add(scatterLightRays.getLabeledTextField("Scattered Rays",null));
        northPanel.add(viewSize.getLabeledTextField("View Size",null));
        propPanel=new JPanel(new BorderLayout());
        propPanel.add(northPanel,BorderLayout.NORTH);
        return propPanel;
    }

    /**
     * Renders the scene into {@link #img}, (re)allocating the image when the
     * requested size changed.  Does nothing until a tree has been built and a
     * positive image size has been set.
     */
    private void renderImage() {
        if(tree==null || sx<=0 || sy<=0) {
            return;
        }
        for(PointLight pl:pointLights) {
            pl.fixPosition(plotArea.getSink());
        }
        for(DirectionalLight dl:directLights) {
            dl.fixDirection(plotArea.getSink());
        }
        setupView();
        if(img==null || img.getWidth()!=sx || img.getHeight()!=sy) {
            img=new BufferedImage(sx,sy,BufferedImage.TYPE_INT_ARGB);
        }
        // Everything the loop body needs is captured once, so a single render
        // uses one consistent set of values even if a setting is edited meanwhile.
        final BufferedImage target=img;
        final int width=target.getWidth();
        final int height=target.getHeight();
        Graphics2D g=target.createGraphics();
        try {
            g.setPaint(Color.black);
            g.fill(new Rectangle2D.Double(0,0,width,height));
        } finally {
            // Release the graphics context; it is only needed for the clear.
            g.dispose();
        }
        final double[] eyePos=eye;
        final double[] center=screenCenter;
        final double[] upDir=up;
        double[] toEye={eyePos[0]-center[0],eyePos[1]-center[1],eyePos[2]-center[2]};
        final double dist=Basic.distance(center,eyePos);
        Basic.normalize(upDir);
        final double[] left=Basic.crossProduct(toEye,upDir);
        Basic.normalize(left);
        final int minSize=Math.min(sx,sy);
        final int rays=raysPerPixel.getValue();
        final double view=viewSize.getValue();
        // Each ray contributes an equal share of the pixel color.
        final float weight=1.0f/rays;
        // The loop body is invoked with column indices taken from [0,width);
        // presumably ThreadHandler spreads them over worker threads - confirm.
        ThreadHandler.instance().dynamicForLoop(null, 0, width, new LoopBody() {
            public void execute(int i) {
                for(int j=0; j<height; ++j) {
                    double[] loc=new double[3];
                    FloatColor col=new FloatColor(0,0,0);
                    for(int k=0; k<rays; ++k) {
                        // The first ray uses the pixel's (i,j) position itself;
                        // additional rays are jittered inside the pixel.
                        double ri=0;
                        double rj=0;
                        if(k>0) {
                            ri=Math.random();
                            rj=Math.random();
                        }
                        double u=view*((i+ri)-0.5*width)/minSize;
                        double v=view*((j+rj)-0.5*height)/minSize;
                        loc[0]=center[0]-dist*left[0]*u-dist*upDir[0]*v;
                        loc[1]=center[1]-dist*left[1]*u-dist*upDir[1]*v;
                        loc[2]=center[2]-dist*left[2]*u-dist*upDir[2]*v;
                        col.add(rayColor(new Ray(eyePos,loc),0),weight);
                    }
                    target.setRGB(i,j,col.getRGB());
                }
            }
        });
    }

    /**
     * Computes the color seen along a ray.
     * @param r the ray to follow through the octree
     * @param cnt the current recursion depth; black is returned once it reaches
     *            the configured maximum
     * @return the shaded color at the first hit, or black if nothing is hit
     */
    private FloatColor rayColor(Ray r,int cnt) {
        if(cnt>=maxRecursions.getValue()) {
            return new FloatColor(0,0,0);
        }
        RayFollower rf=new RayFollower(r,tree);
        if(!rf.follow()) {
            return new FloatColor(0,0,0);
        }
        RTGeometry geom=rf.getIntersectObject();
        // Step back slightly from the hit time so the shading point is on the
        // near side of the hit.
        double[] intersect=r.point(rf.getIntersectTime()*0.999999);
        double[] norm=geom.getNormal(intersect);
        // normRay is deliberately built before norm is normalized in place;
        // whether Ray copies or keeps the array is not visible here, so the
        // original order is preserved.
        Ray normRay=new Ray(new double[]{0,0,0},norm);
        Basic.normalize(norm);
        float[] ret=ambientLight.getColor().getRGBColorComponents(null);
        DrawAttributes attr=geom.getAttributes(intersect);
        float[] col=attr.getColor().getRGBColorComponents(null);
        ret[0]*=col[0];
        ret[1]*=col[1];
        ret[2]*=col[2];
        shadePointLights(ret,col,intersect,normRay);
        shadeDirectionalLights(ret,col,intersect,normRay);
        float opacity=(float)attr.getOpacity();
        float reflectivity=(float)attr.getReflectivity();
        // Only the locally lit part is dimmed by the opacity.
        ret[0]*=opacity;
        ret[1]*=opacity;
        ret[2]*=opacity;
        if(reflectivity>0) {
            shadeReflection(ret,r,intersect,norm,reflectivity,cnt);
        }
        if(opacity<1) {
            shadeTransmission(ret,r,rf,opacity,cnt);
        }
        shadeScatter(ret,intersect,norm,cnt);
        return new FloatColor(ret[0],ret[1],ret[2]);
    }

    /**
     * Adds the contribution of every point light to ret.
     * NOTE(review): a light is applied when the ray toward it hits nothing or
     * hits at a time &lt;=1.  If time 1 is the light position this lights points
     * that have an occluder in front of the light - confirm against the
     * Ray/RayFollower time convention.
     */
    private void shadePointLights(float[] ret,float[] col,double[] intersect,Ray normRay) {
        for(int i=0; i<pointLights.size(); ++i) {
            PointLight light=pointLights.get(i);
            Ray toLight=new Ray(intersect,light.getPosition().get());
            RayFollower followLight=new RayFollower(toLight,tree);
            if(!followLight.follow() || followLight.getIntersectTime()<=1) {
                float lightMag=(float)(toLight.dotProduct(normRay)/toLight.mainLength());
                if(lightMag>0) {
                    float[] lightColor=light.getColor().getRGBColorComponents(null);
                    // Attenuate with distance, measured in units of the lighting length scale.
                    lightMag/=(toLight.mainLength()/scale.getValue()+1);
                    ret[0]+=lightColor[0]*lightMag*col[0];
                    ret[1]+=lightColor[1]*lightMag*col[1];
                    ret[2]+=lightColor[2]*lightMag*col[2];
                }
            }
        }
    }

    /**
     * Adds the contribution of every directional light whose shadow ray hits nothing.
     */
    private void shadeDirectionalLights(float[] ret,float[] col,double[] intersect,Ray normRay) {
        for(int i=0; i<directLights.size(); ++i) {
            DirectionalLight light=directLights.get(i);
            // The shadow ray ends at a point 10000 direction-lengths against the light direction.
            double[] farPoint=new Vect3D(intersect).minus(light.getDirection().mult(10000)).get();
            Ray toLight=new Ray(intersect,farPoint);
            RayFollower followLight=new RayFollower(toLight,tree);
            if(!followLight.follow()) {
                float lightMag=(float)(toLight.dotProduct(normRay)/toLight.mainLength());
                if(lightMag>0) {
                    float[] lightColor=light.getColor().getRGBColorComponents(null);
                    ret[0]+=lightColor[0]*lightMag*col[0];
                    ret[1]+=lightColor[1]*lightMag*col[1];
                    ret[2]+=lightColor[2]*lightMag*col[2];
                }
            }
        }
    }

    /**
     * Follows the mirror reflection of r about norm and adds its color,
     * weighted by reflectivity, to ret.
     */
    private void shadeReflection(float[] ret,Ray r,double[] intersect,double[] norm,float reflectivity,int cnt) {
        double[] dir=r.getDirectionVector();
        double mag=dir[0]*norm[0]+dir[1]*norm[1]+dir[2]*norm[2];
        double[] one=new double[intersect.length];
        for(int i=0; i<one.length; ++i) {
            // reflected direction = dir - 2 (dir . norm) norm
            one[i]=intersect[i]+dir[i]-2*mag*norm[i];
        }
        FloatColor fc=rayColor(new Ray(intersect,one),cnt+1);
        ret[0]+=reflectivity*fc.getRed();
        ret[1]+=reflectivity*fc.getGreen();
        ret[2]+=reflectivity*fc.getBlue();
    }

    /**
     * Continues r in the same direction from just past the hit and adds the
     * color found there, weighted by (1-opacity), to ret.
     */
    private void shadeTransmission(float[] ret,Ray r,RayFollower rf,float opacity,int cnt) {
        // Start slightly beyond the hit time so the same surface is not hit again.
        double[] start=r.point(rf.getIntersectTime()*1.00001);
        double[] dir=r.getDirectionVector();
        double[] one=new double[start.length];
        for(int i=0; i<one.length; ++i) {
            one[i]=start[i]+dir[i];
        }
        FloatColor fc=rayColor(new Ray(start,one),cnt+1);
        ret[0]+=fc.getRed()*(1-opacity);
        ret[1]+=fc.getGreen()*(1-opacity);
        ret[2]+=fc.getBlue()*(1-opacity);
    }

    /**
     * Casts the configured number of random rays into the half space on the
     * normal's side and adds their average color to ret.  Scattered rays
     * advance the recursion depth by three.
     * NOTE(review): candidate directions are rejected until one has a positive
     * dot product with norm, so the loop cannot finish if none ever does
     * (e.g. a normal containing NaN).
     */
    private void shadeScatter(float[] ret,double[] intersect,double[] norm,int cnt) {
        final int samples=scatterLightRays.getValue();
        for(int i=0; i<samples;) {
            double[] dir={Math.random()-0.5,Math.random()-0.5,Math.random()-0.5};
            if(Basic.dotProduct(norm,dir)>0) {
                dir[0]+=intersect[0];
                dir[1]+=intersect[1];
                dir[2]+=intersect[2];
                FloatColor fc=rayColor(new Ray(intersect,dir),cnt+3);
                ret[0]+=fc.getRed()/samples;
                ret[1]+=fc.getGreen()/samples;
                ret[2]+=fc.getBlue()/samples;
                i++;
            }
        }
    }

    /**
     * Reads the camera from the plot area: the eye is the camera center, the
     * screen center is the eye plus the camera's forward vector.
     */
    private void setupView() {
        Camera3D cam=plotArea.getCamera();
        eye=cam.getCenter(plotArea.getSink()).get();
        up=cam.getUp(plotArea.getSink()).get();
        screenCenter=(cam.getCenter(plotArea.getSink()).plus(cam.getForward(plotArea.getSink()))).get();
    }

    private PlotArea3D plotArea;

    // Geometry added since the last clearScene().
    private List<RTGeometry> geomList=new ArrayList<RTGeometry>();

    // Requested image size in pixels.
    private int sx;
    private int sy;
    // Render target; allocated on the first render and reused while the size is unchanged.
    private BufferedImage img;
    // Built by sceneComplete(); null until then.
    private Octree tree;
    private List<PointLight> pointLights=new ArrayList<PointLight>();
    private List<DirectionalLight> directLights=new ArrayList<DirectionalLight>();
    private AmbientLight ambientLight=new AmbientLight(new Color(20,20,20));
    // User editable settings shown in the properties panel.
    private EditableInt raysPerPixel=new EditableInt(1);
    private EditableDouble scale=new EditableDouble(100);
    private EditableInt scatterLightRays=new EditableInt(0);
    private EditableInt maxRecursions=new EditableInt(5);
    private EditableDouble viewSize=new EditableDouble(1.0);
    // View vectors refreshed by setupView() at the start of every render.
    private double[] screenCenter=new double[3];
    private double[] eye=new double[3];
    private double[] up=new double[3];

    private transient JComponent propPanel;

    /**
     * An RGB color with unclamped float components, used to accumulate light
     * before it is converted to a packed pixel value.
     */
    private static class FloatColor {
        public FloatColor(float r,float g,float b) {
            red=r;
            green=g;
            blue=b;
        }
        /**
         * Adds another color, scaled by factor, to this one.
         */
        public void add(FloatColor fc,float factor) {
            red+=fc.red*factor;
            green+=fc.green*factor;
            blue+=fc.blue*factor;
        }
        /**
         * @return the packed RGB value of this color with every component
         *         limited to the range [0,1]
         */
        public int getRGB() {
            return new Color(clamp(red),clamp(green),clamp(blue)).getRGB();
        }
        public float getRed() { return red; }
        public float getGreen() { return green; }
        public float getBlue() { return blue; }
        // Limits a component to [0,1]; other values (including NaN) pass through unchanged.
        private static float clamp(float c) {
            if(c<0) { c=0; }
            if(c>1) { c=1; }
            return c;
        }
        private float red,green,blue;
    }

}
